{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# softmax回归的简洁实现\n",
    "\n",
    "我们在[“线性回归的简洁实现”](linear-regression-gluon.ipynb)一节中已经了解了使用Gluon实现模型的便利。下面，让我们再次使用Gluon来实现一个softmax回归模型。首先导入所需的包或模块。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "1"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import d2lzh as d2l\n",
    "from mxnet import gluon, init\n",
    "from mxnet.gluon import loss as gloss, nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 获取和读取数据\n",
    "\n",
    "我们仍然使用Fashion-MNIST数据集和上一节中设置的批量大小。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义和初始化模型\n",
    "\n",
    "在[“softmax回归”](softmax-regression.ipynb)一节中提到，softmax回归的输出层是一个全连接层。因此，我们添加一个输出个数为10的全连接层。我们使用均值为0、标准差为0.01的正态分布随机初始化模型的权重参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential()\n",
    "net.add(nn.Dense(10))\n",
    "net.initialize(init.Normal(sigma=0.01))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## softmax和交叉熵损失函数\n",
    "\n",
    "如果做了上一节的练习，那么你可能意识到了分开定义softmax运算和交叉熵损失函数可能会造成数值不稳定。因此，Gluon提供了一个包括softmax运算和交叉熵损失计算的函数。它的数值稳定性更好。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "4"
    }
   },
   "outputs": [],
   "source": [
    "loss = gloss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义优化算法\n",
    "\n",
    "我们使用学习率为0.1的小批量随机梯度下降作为优化算法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "5"
    }
   },
   "outputs": [],
   "source": [
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "接下来，我们使用上一节中定义的训练函数来训练模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "6"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.7924, train acc 0.743, test acc 0.806\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, loss 0.5727, train acc 0.811, test acc 0.824\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, loss 0.5289, train acc 0.824, test acc 0.827\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, loss 0.5043, train acc 0.832, test acc 0.838\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, loss 0.4901, train acc 0.834, test acc 0.842\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 5\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None,\n",
    "              None, trainer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* Gluon提供的函数往往具有更好的数值稳定性。\n",
    "* 可以使用Gluon更简洁地实现softmax回归。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 尝试调一调超参数，如批量大小、迭代周期和学习率，看看结果会怎样。\n",
    "\n",
    "\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/740)\n",
    "\n",
    "![](../img/qr_softmax-regression-gluon.svg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}